import torch.nn as nn

class ResNet9(nn.Module):
    """ResNet-9 style image classifier.

    Layout::

        conv(64) -> conv(128)+pool -> residual[conv(128), conv(128)]
        -> conv(256)+pool -> conv(512)+pool -> residual[conv(512), conv(512)]
        -> global max pool -> Linear(512 -> num_classes)

    Each "conv" is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU, so only the
    three MaxPool2d(2) stages change the spatial size.

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits.

    The input is a 4-D batch shaped (N, in_channels, H, W); BatchNorm2d rejects
    other ranks. H and W must be at least 8, so that the three 2x pools leave a
    map of at least 1x1. The output is (N, num_classes) raw logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # BatchNorm of the final residual conv, also exposed as a top-level
        # attribute (presumably so callers can reach it directly -- confirm with
        # callers). Because it is registered twice, state_dict() contains it
        # under both ``llBN.*`` and ``res2.1.1.*``. It is created first to keep
        # the attribute / state_dict ordering unchanged.
        self.llBN = nn.BatchNorm2d(512)

        def conv_block(c_in, c_out, pool=False, bn=None):
            """Build Conv3x3 -> BatchNorm -> ReLU, optionally followed by a 2x2 max pool.

            ``bn`` lets the caller supply an existing BatchNorm2d(c_out) module
            (e.g. a shared one) instead of creating a fresh one.
            """
            layers = [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                bn if bn is not None else nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            if pool:
                layers.append(nn.MaxPool2d(2))
            return nn.Sequential(*layers)

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512),
                                  conv_block(512, 512, bn=self.llBN))

        # Global max pool to 1x1, rather than a fixed MaxPool2d(4). A fixed
        # 4x4 window only fits the 4x4 map produced by 32x32 inputs:
        # - smaller inputs raise an error;
        # - inputs of 64 px and above yield more than 512 features and
        #   break ``fc``;
        # - 40-63 px inputs silently drop border activations.
        # On a 4x4 map both give identical results.
        self.classifier = nn.Sequential(nn.AdaptiveMaxPool2d(1),
                                        nn.Flatten())
        self.fc = nn.Linear(512, num_classes)

    def forward(self, xb):
        """Return class logits of shape (N, num_classes) for the image batch ``xb``."""
        out = self.conv1(xb)
        out = self.conv2(out)
        # Residual connection: padding=1 convs keep the shape, so the sum is valid.
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)  # (N, 512)
        out = self.fc(out)
        return out